# -*- coding: utf-8 -*-
#

from django.db.models import F, Count
from django.db.transaction import atomic
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers

from accounts.models import Account
from accounts.serializers import AccountSerializer
from common.const import UUID_PATTERN
from common.serializers import (
    WritableNestedModelSerializer, SecretReadableMixin,
    CommonModelSerializer, MethodSerializer, ResourceLabelsMixin
)
from common.serializers.common import DictSerializer
from common.serializers.fields import LabeledChoiceField, ObjectRelatedField
from labels.models import Label
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from ...const import Category, AllTypes
from ...models import Asset, Node, Platform, Protocol, Host, Device, Database, Cloud, Web, Custom

# Public serializers exported by this module (order preserved).
__all__ = [
    'AssetSerializer',
    'AssetSimpleSerializer',
    'MiniAssetSerializer',
    'AssetTaskSerializer',
    'AssetsTaskSerializer',
    'AssetProtocolsSerializer',
    'AssetDetailSerializer',
    'DetailMixin',
    'AssetAccountSerializer',
    'AccountSecretSerializer',
    'AssetProtocolsPermsSerializer',
    'AssetLabelSerializer',
]


class AssetProtocolsSerializer(serializers.ModelSerializer):
    """Protocol entry (``name`` + ``port``) attached to an asset.

    In import/export files a protocol is written as the string ``"name/port"``.
    """
    port = serializers.IntegerField(required=False, allow_null=True, max_value=65535, min_value=0)

    def get_render_help_text(self):
        """Return the import help text; plural form when used as a list field."""
        if self.parent and self.parent.many:
            return _('Protocols, format is ["protocol/port"]')
        else:
            return _('Protocol, format is name/port')

    def to_file_representation(self, data):
        """Render ``{'name': ..., 'port': ...}`` as ``"name/port"``."""
        return '{name}/{port}'.format(**data)

    def to_file_internal_value(self, data):
        """Parse a ``"name/port"`` string from an import file.

        The port is returned as a string; the ``port`` field converts it.
        """
        # partition() never raises: a value without '/' yields an empty port,
        # and the port field rejects it with a normal validation error. The
        # previous split()+unpack crashed with ValueError on such input.
        name, _sep, port = data.strip().partition('/')
        return {'name': name.strip(), 'port': port.strip()}

    class Meta:
        model = Protocol
        fields = ['name', 'port']


class AssetProtocolsPermsSerializer(AssetProtocolsSerializer):
    """Protocol serializer that also exposes ``public`` and ``setting``."""

    class Meta(AssetProtocolsSerializer.Meta):
        fields = [*AssetProtocolsSerializer.Meta.fields, 'public', 'setting']


class AssetLabelSerializer(serializers.ModelSerializer):
    """Label entry (``id``/``name``/``value``) nested in asset payloads."""

    class Meta:
        model = Label
        fields = ['id', 'name', 'value']
        extra_kwargs = {
            # Drop the default unique-key validation on id
            # (presumably so existing labels can be referenced by id — confirm).
            'id': {'validators': []},
            # name and value are optional on input.
            **{field_name: {'required': False} for field_name in ('name', 'value')},
        }


class AssetPlatformSerializer(serializers.ModelSerializer):
    """Minimal platform reference: ``id`` plus an optional ``name``."""

    class Meta:
        model = Platform
        fields = ['id', 'name']
        extra_kwargs = dict(name={'required': False})


class AssetAccountSerializer(AccountSerializer):
    """Account serializer for accounts nested inside an asset payload.

    An incoming ``id`` is not the new account's id. It names an existing
    account whose secret is copied onto the account being created.
    """
    add_org_fields = False
    asset = serializers.PrimaryKeyRelatedField(queryset=Asset.objects, required=False, write_only=True)
    # Id of the account whose secret should be cloned; set per item in
    # to_internal_value() and read back in validate().
    clone_id = None

    def to_internal_value(self, data):
        """Pop the clone source ``id`` from ``data`` before normal parsing."""
        # With many=True the same child instance validates every item, so
        # always reset first. Otherwise a clone id left over from a previous
        # item could copy the wrong account's secret into this one.
        self.clone_id = None

        # During import, data is sometimes a str (no id to clone from).
        if isinstance(data, str):
            return super().to_internal_value(data)

        clone_id = data.pop('id', None)
        ret = super().to_internal_value(data)
        self.clone_id = clone_id
        return ret

    def set_secret(self, attrs):
        """Copy the secret of the account referenced by ``clone_id`` into ``attrs``.

        Raises ValidationError if the referenced account does not exist.
        """
        _id = self.clone_id
        if not _id:
            return attrs

        account = Account.objects.filter(id=_id).first()
        if account is None:
            # Previously .get() surfaced DoesNotExist as a server error.
            raise serializers.ValidationError({'id': _("Invalid data")})
        attrs['secret'] = account.secret
        return attrs

    def validate(self, attrs):
        attrs = super().validate(attrs)
        return self.set_secret(attrs)

    def get_render_help_text(self):
        return _('Accounts, format [{"name": "x", "username": "x", "secret": "x", "secret_type": "password"}]')

    class Meta(AccountSerializer.Meta):
        # Inherit the account fields minus those not relevant when nested in an asset.
        fields = [
            f for f in AccountSerializer.Meta.fields
            if f not in [
                'spec_info', 'connectivity', 'labels', 'created_by',
                'date_update', 'date_created'
            ]
        ]
        extra_kwargs = {
            **AccountSerializer.Meta.extra_kwargs,
        }


class AccountSecretSerializer(SecretReadableMixin, CommonModelSerializer):
    """Account fields including ``secret``, which is explicitly not write-only here."""

    class Meta:
        model = Account
        fields = ['name', 'username', 'privileged', 'secret_type', 'secret']
        extra_kwargs = {'secret': {'write_only': False}}


class NodeDisplaySerializer(serializers.ListField):
    """List of node full paths such as ``"/org_name/node_name"``, passed through as-is."""

    def get_render_help_text(self):
        return _('Node path, format ["/org_name/node_name"], if node not exist, will create it')

    def to_internal_value(self, data):
        # Accept the payload unchanged; no per-item parsing happens here.
        return data

    # Output is likewise returned unchanged.
    to_representation = to_internal_value


class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, WritableNestedModelSerializer):
    """Read/write serializer for assets.

    On write:
    * protocols default to the platform's required/default protocols when
      omitted on create;
    * ``nodes_display`` paths are resolved to nodes, which are created when missing;
    * the ``accounts`` payload is popped from ``initial_data`` at construction
      and created separately once the asset has been saved.
    """
    category = LabeledChoiceField(choices=Category.choices, read_only=True, label=_('Category'))
    type = LabeledChoiceField(choices=AllTypes.choices(), read_only=True, label=_('Type'))
    protocols = AssetProtocolsSerializer(many=True, required=False, label=_('Protocols'), default=())
    accounts = AssetAccountSerializer(many=True, required=False, allow_null=True, write_only=True, label=_('Accounts'))
    nodes_display = NodeDisplaySerializer(read_only=False, required=False, label=_("Node path"))
    platform = ObjectRelatedField(queryset=Platform.objects, required=True, label=_('Platform'),
                                  attrs=('id', 'name', 'type'))
    accounts_amount = serializers.IntegerField(read_only=True, label=_('Accounts amount'))
    # Raw ``accounts`` payload popped from initial_data in __init__; consumed by create().
    _accounts = None

    class Meta:
        model = Asset
        fields_fk = ['zone', 'platform']
        fields_mini = ['id', 'name', 'address'] + fields_fk
        fields_small = fields_mini + ['is_active', 'comment']
        fields_m2m = [
            'nodes', 'labels', 'protocols',
            'nodes_display', 'accounts',
            'directory_services',
        ]
        read_only_fields = [
            'accounts_amount', 'category', 'type', 'connectivity', 'auto_config',
            'date_verified', 'created_by', 'date_created', 'date_updated',
        ]
        fields = fields_small + fields_fk + fields_m2m + read_only_fields
        fields_unexport = ['auto_config']
        extra_kwargs = {
            'auto_config': {'label': _('Auto info')},
            'name': {'label': _("Name"), 'initial': 'Asset name'},
            'address': {'label': _('Address')},
            'nodes_display': {'label': _('Node path')},
            'nodes': {'allow_empty': True, 'label': _("Nodes")},
            'directory_services': {
                'required': False,
                'allow_empty': True,
                'default': list, 'label': _("Directory service")
            },
        }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._init_field_choices()
        self._extract_accounts()

    def _extract_accounts(self):
        """Move ``accounts`` out of single-object initial_data into ``self._accounts``."""
        if not getattr(self, 'initial_data', None):
            return
        # Bulk (list) payloads are left untouched.
        if isinstance(self.initial_data, list):
            return
        accounts = self.initial_data.pop('accounts', None)
        self._accounts = accounts

    def _get_protocols_required_default(self):
        """Return ``(required_or_primary, default)`` protocol lists of the target platform."""
        platform = self._asset_platform
        platform_protocols = platform.protocols.all()
        protocols_default = [p for p in platform_protocols if p.default]
        protocols_required = [p for p in platform_protocols if p.required or p.primary]
        return protocols_required, protocols_default

    def _set_protocols_default(self):
        """On create, fill ``protocols`` from the platform when the payload omits it."""
        if not hasattr(self, 'initial_data'):
            return
        protocols = self.initial_data.get('protocols')
        if protocols is not None:
            return
        if getattr(self, 'instance', None):
            return

        protocols_required, protocols_default = self._get_protocols_required_default()
        # Keyed by id so a protocol that is both required and default appears once.
        protocol_map = {str(protocol.id): protocol for protocol in protocols_required + protocols_default}
        protocols = list(protocol_map.values())
        protocols_data = [{'name': p.name, 'port': p.port} for p in protocols]
        self.initial_data['protocols'] = protocols_data

    def _init_field_choices(self):
        """Narrow category/type choices using the last request path segment.

        The segment has trailing 's' characters stripped, e.g. ``.../hosts/`` -> ``host``.
        """
        request = self.context.get('request')
        if not request:
            return
        category = request.path.strip('/').split('/')[-1].rstrip('s')
        field_category = self.fields.get('category')
        if not field_category:
            return
        field_category.choices = Category.filter_choices(category)
        field_type = self.fields.get('type')
        if not field_type:
            return
        field_type.choices = AllTypes.filter_choices(category)

    @classmethod
    def setup_eager_loading(cls, queryset):
        """ Perform necessary eager loading of data. """
        queryset = queryset.prefetch_related('zone', 'nodes', 'protocols', 'directory_services') \
            .prefetch_related('platform', 'platform__automation') \
            .annotate(category=F("platform__category")) \
            .annotate(type=F("platform__type")) \
            .annotate(accounts_amount=Count('accounts'))
        return queryset

    @staticmethod
    def perform_nodes_display_create(instance, nodes_display):
        """Set ``instance.nodes`` from node full paths, creating missing nodes.

        Paths without a leading '/' are prefixed with ``/<org name>/``.
        """
        if not nodes_display:
            return
        nodes_to_set = []
        for full_value in nodes_display:
            if not full_value.startswith('/'):
                full_value = '/' + instance.org.name + '/' + full_value
            node = Node.objects.filter(full_value=full_value).first()
            if not node:
                node = Node.create_node_by_full_value(full_value)
            # Append exactly once; existing nodes used to be appended twice.
            nodes_to_set.append(node)
        instance.nodes.set(nodes_to_set)

    @property
    def _asset_platform(self):
        """Platform referenced by the payload, falling back to the instance's platform.

        Accepts a raw id or a ``{'id': ...}``/``{'pk': ...}`` dict.
        Raises ValidationError when no platform can be found.
        """
        platform_id = self.initial_data.get('platform')
        if isinstance(platform_id, dict):
            platform_id = platform_id.get('id') or platform_id.get('pk')

        if not platform_id and self.instance:
            platform = self.instance.platform
        else:
            platform = Platform.objects.filter(id=platform_id).first()

        if not platform:
            raise serializers.ValidationError({'platform': _("Platform not exist")})
        return platform

    def validate_zone(self, value):
        """Keep the zone only when the platform has gateways enabled; otherwise clear it."""
        platform = self._asset_platform
        if platform.gateway_enabled:
            return value
        else:
            return None

    def validate_nodes(self, nodes):
        """Default to the ``node_id`` query param node, else the org root.

        Given nodes, or a non-empty ``nodes_display``, are returned unchanged.
        """
        if nodes:
            return nodes
        nodes_display = self.initial_data.get('nodes_display')
        if nodes_display:
            return nodes
        default_node = Node.org_root()
        request = self.context.get('request')
        if not request:
            return [default_node]
        node_id = request.query_params.get('node_id')
        if not node_id:
            return [default_node]
        nodes = Node.objects.filter(id=node_id)
        return nodes

    def is_valid(self, raise_exception=False):
        # Defaults must be injected into initial_data before field validation runs.
        self._set_protocols_default()
        return super().is_valid(raise_exception=raise_exception)

    def validate_protocols(self, protocols_data):
        """De-duplicate protocols by name and check the platform's required ones exist.

        Raises ValidationError for an out-of-range port or a missing required/primary protocol.
        """
        # Keyed by name to de-duplicate; a repeated name keeps its last entry.
        protocols_data_map = {p['name']: p for p in protocols_data}
        for p in protocols_data:
            port = p.get('port', 0)
            # The port field allows null: skip the range check for None
            # instead of failing with a TypeError on the comparison.
            if port is not None and not 0 <= port <= 65535:
                # format() instead of ``str + lazy`` concatenation, so this does
                # not depend on the lazy translation proxy supporting __radd__.
                error = '{}: {}'.format(p.get('name'), _("port out of range (0-65535)"))
                raise serializers.ValidationError(error)

        protocols_required, __ = self._get_protocols_required_default()
        protocols_not_found = [p.name for p in protocols_required if p.name not in protocols_data_map]
        if protocols_not_found:
            raise serializers.ValidationError({
                'protocols': _("Protocol is required: {}").format(', '.join(protocols_not_found))
            })
        return list(protocols_data_map.values())

    def validate_platform(self, platform_data):
        """For category-specific asset models, require the platform category to match."""
        check_models = {Host, Device, Database, Cloud, Web, Custom}
        if self.Meta.model not in check_models:
            return platform_data
        model_name = self.Meta.model.__name__.lower()
        if model_name != platform_data.category:
            raise serializers.ValidationError({
                'platform': f"Platform does not match: {platform_data.name}"
            })
        return platform_data

    @staticmethod
    def update_account_su_from(accounts, include_su_from_accounts):
        """Link created accounts to their su_from account among the same batch.

        ``include_su_from_accounts`` maps an account name to the
        ``(username, secret_type)`` of the account it should su from.
        """
        if not include_su_from_accounts:
            return
        name_map = {account.name: account for account in accounts}
        username_secret_type_map = {
            (account.username, account.secret_type): account for account in accounts
        }

        for name, username_secret_type in include_su_from_accounts.items():
            account = name_map.get(name)
            if not account:
                continue
            su_from_account = username_secret_type_map.get(username_secret_type)
            if su_from_account:
                account.su_from = su_from_account
                account.save()

    def accounts_create(self, accounts_data, asset):
        """Create the accounts popped from the payload and bind them to ``asset``.

        Raises ValidationError for malformed items or unknown template/su_from ids.
        """
        from accounts.models import AccountTemplate
        if not accounts_data:
            return

        # Check every item, not only the first one.
        if not all(isinstance(data, dict) for data in accounts_data):
            raise serializers.ValidationError({'accounts': _("Invalid data")})

        # account name -> (username, secret_type) of its su_from account
        su_from_name_username_secret_type_map = {}
        for data in accounts_data:
            data['asset'] = asset.id
            name = data.get('name')
            su_from = data.pop('su_from', None)
            template_id = data.get('template', None)
            if template_id:
                template = AccountTemplate.objects.filter(id=template_id).first()
                if template is None:
                    raise serializers.ValidationError({'accounts': _("Invalid data")})
                data['params'] = data.pop('push_params', {})
                if template.su_from:
                    su_from_name_username_secret_type_map[template.name] = (
                        template.su_from.username, template.su_from.secret_type
                    )
            elif isinstance(su_from, dict):
                su_from = Account.objects.filter(id=su_from.get('id')).first()
                if su_from is None:
                    raise serializers.ValidationError({'accounts': _("Invalid data")})
                su_from_name_username_secret_type_map[name] = (
                    su_from.username, su_from.secret_type
                )
        s = AssetAccountSerializer(data=accounts_data, many=True)
        s.is_valid(raise_exception=True)
        accounts = s.save()
        self.update_account_su_from(accounts, su_from_name_username_secret_type_map)

    @atomic
    def create(self, validated_data):
        nodes_display = validated_data.pop('nodes_display', '')
        instance = super().create(validated_data)
        self.accounts_create(self._accounts, instance)
        self.perform_nodes_display_create(instance, nodes_display)
        return instance

    @staticmethod
    def sync_platform_protocols(instance, old_platform):
        """After a platform change, add the new platform's protocols the asset lacks."""
        if str(old_platform.id) == str(instance.platform_id):
            return

        platform = instance.platform
        platform_protocols = {
            p['name']: p['port']
            for p in platform.protocols.values('name', 'port')
        }

        existing = set(instance.protocols.values_list('name', flat=True))
        missing_names = set(platform_protocols) - existing
        objs = [
            Protocol(name=name, port=platform_protocols[name], asset_id=instance.id)
            for name in missing_names
        ]
        Protocol.objects.bulk_create(objs)

    @atomic
    def update(self, instance, validated_data):
        old_platform = instance.platform
        nodes_display = validated_data.pop('nodes_display', '')
        instance = super().update(instance, validated_data)
        self.sync_platform_protocols(instance, old_platform)
        self.perform_nodes_display_create(instance, nodes_display)
        return instance


class DetailMixin(serializers.Serializer):
    """Adds detail-only fields to an asset model serializer.

    Adds ``accounts``, ``spec_info``, ``gathered_info`` and ``auto_config``.
    The ``spec_info``/``gathered_info`` serializers are chosen by asset category.
    """
    accounts = AssetAccountSerializer(many=True, required=False, label=_('Accounts'))
    spec_info = MethodSerializer(label=_('Spec info'), read_only=True)
    gathered_info = MethodSerializer(label=_('Gathered info'), read_only=True)
    auto_config = serializers.DictField(read_only=True, label=_('Auto info'))

    def get_instance(self):
        """Return ``self.instance``, loading it from a UUID in the request path if unset."""
        if self.instance:
            return self.instance
        request = self.context.get('request')
        # Without a request (e.g. internal use) there is no path to read from.
        if request is None:
            return self.instance
        matches = UUID_PATTERN.findall(request.path)
        if matches:
            self.instance = Asset.objects.filter(id=matches[0]).first()
        return self.instance

    def get_field_names(self, declared_fields, info):
        # DRF's ModelSerializer.get_field_names returns Meta.fields itself, and
        # Meta is shared with the base serializer. Work on a copy so that list
        # is never mutated, and skip names that are already present.
        names = list(super().get_field_names(declared_fields, info))
        extra = ['accounts', 'gathered_info', 'spec_info', 'auto_config']
        names.extend([name for name in extra if name not in names])
        return names

    def get_category(self):
        """Category from the ``category`` query param, else the instance's, else ``'host'``."""
        request = self.context.get('request')
        category = request.query_params.get('category') if request else None
        if category:
            return category
        instance = self.get_instance()
        return instance.category if instance else 'host'

    def get_gathered_info_serializer(self):
        category = self.get_category()
        from .info.gathered import category_gathered_serializer_map
        serializer_cls = category_gathered_serializer_map.get(category, DictSerializer)
        return serializer_cls()

    def get_spec_info_serializer(self):
        category = self.get_category()
        from .info.spec import category_spec_serializer_map
        serializer_cls = category_spec_serializer_map.get(category, DictSerializer)
        return serializer_cls()


class AssetDetailSerializer(DetailMixin, AssetSerializer):
    """AssetSerializer combined with DetailMixin's detail-only fields."""


class MiniAssetSerializer(serializers.ModelSerializer):
    """Compact asset representation using AssetSerializer's ``fields_mini``."""

    class Meta:
        model = Asset
        fields = list(AssetSerializer.Meta.fields_mini)


class AssetSimpleSerializer(serializers.ModelSerializer):
    """Flat asset summary with connectivity status."""

    class Meta:
        model = Asset
        fields = [
            'id',
            'name',
            'address',
            'port',
            'connectivity',
            'date_verified',
        ]


class AssetsTaskSerializer(serializers.Serializer):
    """Payload for running an action over a set of assets; returns the task id."""
    ACTION_CHOICES = tuple((action, action) for action in ('refresh', 'test'))

    task = serializers.CharField(read_only=True)
    action = serializers.ChoiceField(choices=ACTION_CHOICES, write_only=True)
    assets = serializers.PrimaryKeyRelatedField(
        queryset=Asset.objects,
        required=False,
        allow_empty=True,
        many=True,
    )


class AssetTaskSerializer(AssetsTaskSerializer):
    """Single-asset variant of AssetsTaskSerializer with account-related actions."""
    ACTION_CHOICES = AssetsTaskSerializer.ACTION_CHOICES + (
        ('push_system_user', 'push_system_user'),
        ('test_system_user', 'test_system_user'),
    )

    action = serializers.ChoiceField(choices=ACTION_CHOICES, write_only=True)
    asset = serializers.PrimaryKeyRelatedField(
        queryset=Asset.objects,
        required=False,
        allow_empty=True,
        many=False,
    )
    accounts = serializers.PrimaryKeyRelatedField(
        queryset=Account.objects,
        required=False,
        allow_empty=True,
        many=True,
    )
